
Conversation


@danisereb danisereb commented Jan 5, 2026

What does this PR do?

Type of change: new feature

Overview: Add support for MXFP8 PTQ, enabling MXFP8 hardware acceleration during inference on Blackwell GPUs.

Usage

export MODEL_PATH=/my_home/hf_models/nvidia/OpenMath2-Llama3.1-8B
export OUTPUT_PATH=/my_home/hf_models/nvidia/OpenMath2-Llama3.1-8B-MXFP8
mkdir -p $OUTPUT_PATH

python examples/llm_ptq/hf_ptq.py \
--export_fmt hf \
--dataset cnn_dailymail \
--pyt_ckpt_path $MODEL_PATH \
--export_path $OUTPUT_PATH \
--qformat mxfp8
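
For reference, the same configuration can presumably also be applied programmatically via the modelopt API (a sketch; the calibration loop is a placeholder and mtq.MXFP8_DEFAULT_CFG is the constant this PR wires up in hf_ptq.py):

import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("nvidia/OpenMath2-Llama3.1-8B")

def forward_loop(model):
    # run a small calibration dataset (e.g. cnn_dailymail samples) through the model here
    ...

model = mtq.quantize(model, mtq.MXFP8_DEFAULT_CFG, forward_loop)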

The hf_quant_config.json of the output checkpoint:

{
    "producer": {
        "name": "modelopt",
        "version": "0.41.0.dev50+g7a796a875"
    },
    "quantization": {
        "quant_algo": "MXFP8",
        "kv_cache_quant_algo": "FP8",
        "group_size": 32,
        "exclude_modules": [
            "lm_head"
        ]
    }
}

And config.json (only the quantization_config):

...
    "quantization_config": {
        "ignore": [
            "lm_head"
        ],
        "quant_algo": "MXFP8",
        "kv_cache_scheme": {
            "dynamic": false,
            "num_bits": 8,
            "type": "float"
        },
        "producer": {
            "name": "modelopt",
            "version": "0.41.0.dev50+g7a796a875"
        },
        "quant_method": "modelopt"
    }

Testing

Used hf_ptq.py to quantize the model nvidia/OpenMath2-Llama3.1-8B (available on Hugging Face); see the example command above.

Checked that the generated MXFP8 checkpoint can be loaded with vLLM (required changes in vLLM, not merged to main).

Added tests for MXFP8QTensor in tests/gpu/torch/quantization/test_qtensor_cuda.py.
Added "mxfp8" in ‎tests/examples/llm_ptq/test_llm_ptq.py

Support for Nemotron Models

Verify that Nemotron Nano V3 BF16 can be converted to MXFP8 using hf_ptq.py:
https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added MXFP8 quantization format support with new scaling mechanisms and quantization utilities.
    • Updated configuration options, example scripts, and utilities to recognize and process MXFP8 quantization workflows.
    • Extended quantization export pipelines to handle MXFP8 quantized models.
  • Tests

    • Expanded test coverage for MXFP8 quantization across various tensor shapes, data types, and device configurations.



copy-pr-bot bot commented Jan 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@danisereb danisereb marked this pull request as ready for review January 6, 2026 12:31
@danisereb danisereb requested review from a team as code owners January 6, 2026 12:31
@danisereb danisereb requested review from mxinO and sugunav14 January 6, 2026 12:31
@danisereb danisereb force-pushed the support_mxfp8 branch 2 times, most recently from 16f12fa to 88b6869 on January 6, 2026 13:18
@danisereb danisereb requested a review from meenchen January 6, 2026 13:49
@sugunav14
Contributor

Could you also add the corresponding unit tests for impacted functions in quant_utils.py here? Thanks!

# Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
scale_factor = torch.exp2(127 - e8m0_scale.float())

# NOTE: vLLM/flashinfer may require this behavior:
Collaborator

is this required? Should we assert e8m0_scale != 0?

Author

AFAIU, it doesn't align with the MXFP8 specification.
But one of my teammates said that it worked for him in a certain case.

So I wanted to leave some documentation for it for future reference.
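
For future readers, a minimal sketch of the E8M0 encoding under discussion (the bias of 127 and the 0xFF NaN encoding are per the OCP MX spec; this is not the PR's code):

import torch

e8m0 = torch.tensor([0, 127, 254], dtype=torch.uint8)  # biased exponents
unbiased = e8m0.float() - 127.0                         # -127, 0, 127
dequant_scale = torch.exp2(unbiased)                    # 2**-127, 1.0, 2**127
quant_scale = torch.exp2(-unbiased)                     # the inverse used at quantization time
# 255 encodes NaN in E8M0; 0 is a legal (tiny) scale of 2**-127,
# so asserting e8m0 != 0 would reject a value the spec allows.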

# sm89
PTQCommand(quant="fp8", min_sm=89),
PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89),
# sm100
PTQCommand(quant="mxfp8", min_sm=100),
Collaborator

does hopper support mxfp8?

Author

@danisereb danisereb Jan 7, 2026

Blackwell has hardware acceleration for MXFP8.
Hopper does not.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#MXFP8-and-block-scaling

NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: MXFP8.

See what we have for NVFP4 (line below the "mxfp8"):

PTQCommand(quant="nvfp4", min_sm=100),


codecov bot commented Jan 6, 2026

Codecov Report

❌ Patch coverage is 25.00000% with 60 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.02%. Comparing base (406c18c) to head (626f18e).
⚠️ Report is 1 commit behind head on main.

Files with missing lines | Patch % | Lines
...odelopt/torch/quantization/qtensor/mxfp8_tensor.py | 25.33% | 56 Missing ⚠️
.../torch/quantization/nn/modules/tensor_quantizer.py | 0.00% | 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #736      +/-   ##
==========================================
- Coverage   74.23%   74.02%   -0.21%     
==========================================
  Files         192      193       +1     
  Lines       19033    19113      +80     
==========================================
+ Hits        14129    14149      +20     
- Misses       4904     4964      +60     


Contributor

@meenchen meenchen left a comment

LGTM

assert dequant_tensor.shape == input_shape, (
    f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}"
)
assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), (
Contributor

We can also compare with the fake quant here.

Author

@danisereb danisereb Jan 15, 2026

As far as I understand, fake quant is tested by test_qtensor_accuracy (part of class TestQTensor), in the code below the comment # compare with fake quant as well.

I added a test case for MXFP8 to test_qtensor_accuracy and checked that it works using this command:

pytest --maxfail 1 tests/gpu/torch/quantization/test_qtensor_cuda.py -k "test_qtensor_accuracy"

All the new MXFP8 tests also worked, using this command:

pytest tests/gpu/torch/quantization/test_qtensor_cuda.py -k "test_mxfp8"
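
For context, a self-contained reference fake-quant one could compare the real MXFP8 round trip against in such a test (independent of the PR's internals; assumes the last dimension is divisible by 32 and the round-up exponent rule described in this PR):

import torch

def mxfp8_fake_quant_ref(x: torch.Tensor, block: int = 32) -> torch.Tensor:
    """Quantize-dequantize in high precision, one shared E8M0 scale per 32-element block."""
    xb = x.float().reshape(*x.shape[:-1], -1, block)
    amax = xb.abs().amax(dim=-1, keepdim=True)
    exp = torch.clamp(torch.ceil(torch.log2(amax / 448.0)), -127, 127)  # zero blocks clamp to -127
    scale = torch.exp2(exp)
    q = torch.clamp(xb / scale, -448.0, 448.0).to(torch.float8_e4m3fn)
    return (q.float() * scale).reshape(x.shape).to(x.dtype)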

Contributor

coderabbitai bot commented Jan 15, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

This change introduces support for MXFP8 quantization format across the codebase. A new MXFP8QTensor class implements block-based FP8 E4M3 quantization with E8M0 shared scales. MXFP8 support is integrated into configuration, quantization utilities, export workflows, and test coverage.

Changes

Cohort / File(s) | Summary
New MXFP8 Tensor Quantization Core
modelopt/torch/quantization/qtensor/mxfp8_tensor.py
New MXFP8QTensor class with quantization/dequantization, scale computation, E8M0 exponent handling, and support for multi-dimensional tensors. Block size of 32 elements on last dimension.
MXFP8 Module Exports
modelopt/torch/quantization/qtensor/__init__.py
Re-exports MXFP8QTensor from mxfp8_tensor module into public qtensor API.
Quantizer Integration
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Adds MXFP8QTensor support in real quantization path for (4, 3) block configuration, validated against BLOCK_SIZE.
Quantization Configuration & Constants
modelopt/torch/export/model_config.py
Adds QUANTIZATION_MXFP8 constant definition.
Quantization Utilities
modelopt/torch/export/quant_utils.py
Extends weight scaling factor computation, layer quantization detection, layer config processing, and weight quantization to recognize and handle MXFP8 format.
HuggingFace Export
modelopt/torch/export/unified_export_hf.py
Adds MXFP8QTensor import and QUANTIZATION_MXFP8 support in \_export_quantized_weight for E8M0 scale registration.
Configuration & Examples
examples/llm_ptq/hf_ptq.py, examples/llm_ptq/scripts/huggingface_example.sh
Adds "mxfp8" to QUANT_CFG_CHOICES mapping and shell script validation for valid quantization formats.
Test Coverage
tests/examples/llm_ptq/test_llm_ptq.py
Adds MXFP8 quantization test case with min_sm=100 requirement to existing PTQ test scenarios.
Comprehensive MXFP8 Tests
tests/gpu/torch/quantization/test_qtensor_cuda.py
Extensive test suite for MXFP8QTensor covering quantization/dequantization, scale handling, shape/dtype assertions, boundary values, error cases, and multi-device/dtype scenarios.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The pull request title 'Add support for MXFP8 PTQ' accurately and concisely describes the main objective of the changeset, which is to introduce MXFP8 post-training quantization support across multiple modules and examples.
Docstring Coverage | ✅ Passed | Docstring coverage is 85.19% which is sufficient. The required threshold is 80.00%.
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.




Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/qtensor/mxfp8_tensor.py`:
- Around lines 93-138: get_weights_scaling_factor_from_quantizer currently assumes 2D weights and computes expected_shape = (out_dim, in_dim // BLOCK_SIZE), which breaks for 3D MoE weights (num_experts, out_dim, in_dim) because reduce_block_amax yields a 3D scale. Update the method to detect MoE by checking weight.dim() == 3 and, in that case, set expected_shape = (num_experts, out_dim, in_dim // cls.BLOCK_SIZE) (or mirror the NVFP4 transpose guard behavior before calling this method). Then, after pulling weight_quantizer._scale, ensure scale.shape exactly equals expected_shape (allowing a reshape only when numel matches) and raise/assert with a clear message if it does not. Reference symbols: get_weights_scaling_factor_from_quantizer, get_weights_scaling_factor, cls.BLOCK_SIZE, cls.SCALE_DTYPE, and weight_quantizer._scale.
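
To make the expected shapes concrete, a hypothetical helper sketching the 2D/3D distinction described above (not the PR's code):

import torch

BLOCK_SIZE = 32  # per-block group size along the last (input) dimension

def expected_mxfp8_scale_shape(weight: torch.Tensor) -> tuple[int, ...]:
    if weight.dim() == 2:   # dense: (out_dim, in_dim)
        out_dim, in_dim = weight.shape
        return (out_dim, in_dim // BLOCK_SIZE)
    if weight.dim() == 3:   # MoE: (num_experts, out_dim, in_dim)
        num_experts, out_dim, in_dim = weight.shape
        return (num_experts, out_dim, in_dim // BLOCK_SIZE)
    raise ValueError(f"Unsupported weight rank: {weight.dim()}")
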
🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_hf.py (1)

301-309: Minor redundancy: weight is already available from line 250.

The MXFP8 export logic is correct. However, weight is already fetched at line 250 via getattr(sub_module, weight_name), so line 303 re-fetches the same value unnecessarily.

♻️ Optional: reuse existing weight variable
         elif quantization_format == QUANTIZATION_MXFP8:
             # MXFP8 uses dynamic block quantization with E8M0 scales (uint8)
-            weight = getattr(sub_module, weight_name)
             e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
                 weight, weight_quantizer
             )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0f05d67 and a764b32.

📒 Files selected for processing (10)
  • examples/llm_ptq/hf_ptq.py
  • examples/llm_ptq/scripts/huggingface_example.sh
  • modelopt/torch/export/model_config.py
  • modelopt/torch/export/quant_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/qtensor/__init__.py
  • modelopt/torch/quantization/qtensor/mxfp8_tensor.py
  • tests/examples/llm_ptq/test_llm_ptq.py
  • tests/gpu/torch/quantization/test_qtensor_cuda.py
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (2)
  • MXFP8QTensor (26-269)
  • quantize (195-222)
modelopt/torch/export/unified_export_hf.py (1)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (2)
  • MXFP8QTensor (26-269)
  • get_weights_scaling_factor_from_quantizer (94-138)
tests/examples/llm_ptq/test_llm_ptq.py (2)
tests/_test_utils/examples/llm_ptq_utils.py (1)
  • PTQCommand (28-87)
tests/_test_utils/torch/quantization/quant_utils.py (1)
  • quant (19-30)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (3)
  • MXFP8QTensor (26-269)
  • get_weights_scaling_factor_from_quantizer (94-138)
  • quantize_with_scale (141-192)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • block_sizes (360-362)
  • block_sizes (365-367)
🔇 Additional comments (32)
examples/llm_ptq/scripts/huggingface_example.sh (2)

56-58: LGTM!

The mxfp8 format is correctly added to the valid quantization formats list and the error message is updated consistently.


207-210: Verify if mxfp8 should be added to the supported formats for TRT-LLM torch runtime.

The mxfp8 format is not included in the check on line 207, meaning MXFP8-quantized models will exit early with a message to use TensorRT-LLM for deployment. If MXFP8 should support the same workflow as fp8 and nvfp4 (continuing to run_tensorrt_llm.py), consider adding it:

-    if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
+    if [[ ! " fp8 nvfp4 bf16 fp16 mxfp8 " =~ " ${QFORMAT} " ]]; then
modelopt/torch/export/model_config.py (1)

38-38: LGTM!

The new QUANTIZATION_MXFP8 constant follows the established naming convention and is correctly placed among related quantization format identifiers.

tests/examples/llm_ptq/test_llm_ptq.py (1)

117-117: LGTM!

The MXFP8 test case is correctly configured with min_sm=100 to ensure it only runs on Blackwell GPUs which have hardware acceleration for MXFP8.

modelopt/torch/quantization/qtensor/__init__.py (1)

23-23: LGTM!

The mxfp8_tensor module export follows the established pattern and is correctly positioned alphabetically among the other tensor module imports.

examples/llm_ptq/hf_ptq.py (3)

175-191: LGTM!

The mxfp8 format is correctly added to the auto-quantize validation list, enabling MXFP8 as a valid format option for automatic per-layer quantization search.


759-774: LGTM!

The mxfp8 format is correctly added to the mono-quantize validation list for the HF export path.


86-86: LGTM!

The mxfp8 format is correctly mapped to mtq.MXFP8_DEFAULT_CFG in the quantization configuration choices dictionary. The constant is properly defined and exported in the mtq module.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

52-52: LGTM!

The import addition for MXFP8QTensor is correctly placed alongside other quantized tensor imports.


693-703: LGTM!

The MXFP8 branch correctly:

  • Validates block size matches the MXFP8 spec (32)
  • Uses MXFP8QTensor.quantize() which handles block-based quantization internally
  • Stores scales in the same manner as MXFP4

The distinction between MXFP4 (2, 1) and MXFP8 (4, 3) num_bits is clear and properly separated.
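
Schematically, the dispatch described above looks roughly like the following (not the PR's actual code; the local names and return values are assumptions):

# num_bits and block_sizes come from the quantizer's configuration.
if num_bits == (2, 1):        # MXFP4: FP4 E2M1 elements
    out = MXFP4QTensor.quantize(inputs, block_sizes[-1])
elif num_bits == (4, 3):      # MXFP8: FP8 E4M3 elements (this PR)
    assert block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE  # MXFP8 mandates a block size of 32
    out = MXFP8QTensor.quantize(inputs)                # quantized data plus E8M0 scales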

tests/gpu/torch/quantization/test_qtensor_cuda.py (9)

18-18: LGTM!

The imports for math and MXFP8QTensor are correctly added to support the new MXFP8 tests.

Also applies to: 27-27


253-260: LGTM!

The MXFP8 test case is correctly added to test_qtensor_accuracy with appropriate configuration matching the MXFP8 spec (block size 32, dynamic type, scale_bits (8,0)).
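
For reference, the quantizer description exercised by this test case presumably looks like the following (key names follow modelopt's block_sizes convention; exact details may differ):

mxfp8_quant_desc = {
    "num_bits": (4, 3),  # FP8 E4M3 elements
    "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},  # one E8M0 scale per 32 elements
}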


616-676: LGTM!

Comprehensive test for MXFP8 quantize/dequantize covering:

  • Multiple devices (cuda, cpu)
  • Multiple dtypes (float32, float16, bfloat16)
  • Various shapes including 3D MoE-like tensors
  • Padding scenarios (dimensions not divisible by 32)
  • Proper assertions for scale dtype, shapes, and quantized data format

The tolerance of rtol=5e-2, atol=5e-2 is reasonable for FP8 quantization precision.


678-716: LGTM!

Excellent test for verifying E8M0 scale computation with known input values. The test validates that per-block max values are preserved through the quantize-dequantize cycle.


718-751: LGTM!

Good boundary value testing for FP8 E4M3 limits (max 448, powers of 2, positive/negative values). The # fmt: off/on markers appropriately preserve the readable tensor formatting.
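
As a quick illustration of why those boundary values are interesting, all of them are exactly representable in E4M3 and survive a cast round trip (a standalone check, not taken from the test file):

import torch

x = torch.tensor([448.0, 256.0, 1.0, -2.0**-6])          # E4M3 max, a power of 2, signed values
assert torch.equal(x.to(torch.float8_e4m3fn).float(), x)  # exact round trip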


753-782: LGTM!

Memory usage test follows the same pattern as the existing NVFP4 test. The 3x threshold is reasonable given MXFP8 stores FP8 data plus uint8 scales.


784-806: LGTM!

Tests for get_weights_scaling_factor with proper shape and dtype validation. The check for E8M0 values ≤ 254 correctly excludes NaN representation (255).


808-824: LGTM!

Good coverage of edge cases for _compute_e8m0_exponent:

  • Zero amax → minimum exponent (-127)
  • E4M3_MAX (448) → exponent 0
  • Normal value (1.0) → computed exponent
  • Very large/small values → clamped to valid range

826-889: LGTM!

Comprehensive error handling tests covering:

  • 1D tensor assertions
  • Non-divisible dimensions
  • Wrong scale dtype
  • Empty tensor
  • 0D tensor (scalar)
  • Non-floating point input
  • Missing scale in dequantize

This ensures robust input validation.

modelopt/torch/export/unified_export_hf.py (1)

35-35: LGTM!

The imports for MXFP8QTensor and QUANTIZATION_MXFP8 are correctly added to support MXFP8 export handling.

Also applies to: 54-54

modelopt/torch/export/quant_utils.py (5)

33-33: LGTM!

The imports for MXFP8QTensor and QUANTIZATION_MXFP8 are correctly added.

Also applies to: 58-58


296-297: LGTM!

The MXFP8 weight scaling factor retrieval correctly delegates to MXFP8QTensor.get_weights_scaling_factor_from_quantizer, which handles both extracting existing scales and computing new ones.


482-489: LGTM!

The MXFP8 detection logic correctly identifies the format by checking:

  • block_sizes is a dict
  • type is "dynamic"
  • scale_bits is (8, 0) (E8M0 format)

This is properly positioned before the FP8_PB_WO/FP8_PB_REAL checks at lines 490-493, ensuring MXFP8 is correctly distinguished from other FP8 block quantization formats.
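
In other words, the detection boils down to a predicate like this (illustrative names, not the actual quant_utils.py code):

def looks_like_mxfp8(num_bits, block_sizes) -> bool:
    return (
        num_bits == (4, 3)
        and isinstance(block_sizes, dict)
        and block_sizes.get("type") == "dynamic"
        and block_sizes.get("scale_bits") == (8, 0)  # E8M0 shared scales
    )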


685-689: LGTM!

The MXFP8 layer config processing correctly maps the "mxfp8" format to "MXFP8" quant_algo with the appropriate group_size, following the same pattern as other quantization formats.


794-795: LGTM!

The to_quantized_weight function correctly uses MXFP8QTensor.quantize_with_scale to apply the pre-computed E8M0 scale to the weight tensor.

modelopt/torch/quantization/qtensor/mxfp8_tensor.py (7)

1-23: LGTM!

Clean module structure with proper license, docstring, imports from existing utilities (reduce_block_amax, reduce_block_padding), and explicit __all__ export.


26-40: LGTM!

Class constants are correctly defined:

  • E4M3_MAX = 448.0 matches FP8 E4M3 max value
  • BLOCK_SIZE = 32 per MXFP8 specification
  • SCALE_DTYPE = torch.uint8 for E8M0 biased exponent storage

42-66: LGTM!

The _compute_e8m0_exponent implementation:

  • Converts to float32 for numerical stability
  • Handles zero values by using torch.where with min_value fallback
  • Correctly computes ceil(log2(amax / E4M3_MAX))
  • Clamps to valid E8M0 range [-127, 127]
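
A reference implementation of that rule, for readers who want the math in one place (a sketch, not the PR's exact code):

import torch

E4M3_MAX = 448.0

def compute_e8m0_exponent_ref(amax: torch.Tensor) -> torch.Tensor:
    amax = amax.float()
    # Replace zeros so log2 never sees 0; those blocks end up at the minimum exponent.
    safe = torch.where(amax > 0, amax, torch.full_like(amax, E4M3_MAX * 2.0**-127))
    exponent = torch.ceil(torch.log2(safe / E4M3_MAX))
    return torch.clamp(exponent, min=-127.0, max=127.0)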

68-91: LGTM!

The get_weights_scaling_factor implementation correctly:

  • Validates 2D minimum dimension
  • Validates divisibility by BLOCK_SIZE
  • Uses existing reduce_block_amax utility
  • Converts to biased uint8 format (exponent + 127)
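
A compact equivalent using plain reshape/amax in place of reduce_block_amax (assumes a 2D weight whose last dimension is divisible by 32; a sketch, not the PR's code):

import torch

E4M3_MAX = 448.0

def e8m0_weight_scales_ref(weight: torch.Tensor, block: int = 32) -> torch.Tensor:
    assert weight.dim() == 2 and weight.shape[-1] % block == 0
    amax = weight.abs().reshape(weight.shape[0], -1, block).amax(dim=-1).float()
    exponent = torch.clamp(torch.ceil(torch.log2(amax / E4M3_MAX)), -127, 127)  # -inf from zero blocks clamps to -127
    return (exponent + 127).to(torch.uint8)  # biased uint8, shape (out_dim, in_dim // block)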

140-192: LGTM!

The quantize_with_scale implementation is well-structured:

  • Proper input validation for dimensions and dtype
  • Flexible scale reshaping to handle different input shapes
  • Correct scale factor computation: 2^(127 - exponent)
  • Proper clamping to E4M3 range before FP8 conversion
  • The NOTE comment documents potential vLLM/flashinfer compatibility consideration
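
The core of that scaling step, as a reference (assumes a 2D weight and an E8M0 scale of shape (out_dim, in_dim // 32); not the PR's implementation):

import torch

def quantize_with_e8m0_ref(weight: torch.Tensor, e8m0: torch.Tensor, block: int = 32) -> torch.Tensor:
    wb = weight.float().reshape(weight.shape[0], -1, block)
    quant_scale = torch.exp2(127.0 - e8m0.float()).unsqueeze(-1)  # 2**(-exponent)
    q = torch.clamp(wb * quant_scale, -448.0, 448.0)              # saturate to the E4M3 range
    return q.to(torch.float8_e4m3fn).reshape(weight.shape)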

194-222: LGTM!

The quantize method correctly implements the full quantization flow:

  • Input validation for empty, dimension, and dtype
  • Padding alignment via reduce_block_padding
  • Per-block amax computation
  • E8M0 exponent computation and biasing
  • Shape restoration via cropping

224-269: LGTM!

The dequantize method correctly reverses the quantization:

  • Requires scale in kwargs (enforced by assertion)
  • Converts quantized data to float for computation
  • Applies padding for block alignment
  • Computes descale as 2^(exponent - 127)
  • Handles scale shape broadcasting
  • Restores original shape via cropping
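
And the inverse, for completeness (same shape assumptions as above; a sketch rather than the PR's implementation):

import torch

def dequantize_mxfp8_ref(q: torch.Tensor, e8m0: torch.Tensor, block: int = 32,
                         dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    qb = q.float().reshape(q.shape[0], -1, block)
    descale = torch.exp2(e8m0.float() - 127.0).unsqueeze(-1)  # 2**(exponent)
    return (qb * descale).reshape(q.shape).to(dtype)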


Tested by test_qtensor_accuracy.

Signed-off-by: Daniel Serebrenik <[email protected]>
@danisereb danisereb requested a review from mxinO January 15, 2026 19:33